import torch
import torch.nn as nn
from iclr23code import Neuron
from math import sqrt


def reset_model_state(slope=None, t=None):
    """Reset all registered spiking neurons, optionally overriding their slope first.

    Args:
        slope: New slope applied to every entry of ``Neuron.instance``. May be a
            Python number or a one-element tensor (unwrapped with ``.item()``).
            ``None`` leaves the existing slopes untouched.
        t: Forwarded unchanged to ``Neuron.reset_state`` (presumably the number of
            time steps -- TODO confirm in ``iclr23code``).
    """
    if slope is not None:
        # Normalise to a Python scalar so each layer gets a fresh float tensor
        # placed on the same device as its current slope.
        scalar = slope.item() if isinstance(slope, torch.Tensor) else slope
        for neuron_layer in Neuron.instance:
            target_device = neuron_layer.slope.device
            neuron_layer.slope = torch.tensor(scalar, dtype=torch.float, device=target_device)
    Neuron.reset_state(t)


# ================================== Some BatchNorm layer ======================================
class Straight(nn.Module):
    """Identity module: ``forward`` returns its input object unchanged.

    Used as a "no normalisation" placeholder wherever a norm module is expected.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Pass-through: the very same tensor object is returned.
        identity_out = x
        return identity_out


def straight(channel=0, threshold=1.0, scale=1.0):
    """Norm-factory-compatible constructor for an identity ``Straight`` module.

    ``channel``, ``threshold`` and ``scale`` exist only so this matches the other
    norm factories' signature; all three are ignored.
    """
    identity = Straight()
    return identity


def batch_norm2d(channel, threshold=1.0, scale=1.0):
    """Norm factory returning a plain ``nn.BatchNorm2d(channel)``.

    ``threshold`` and ``scale`` are accepted for signature compatibility and ignored.
    """
    norm = nn.BatchNorm2d(channel)
    return norm


class TdBN(nn.Module):
    """BatchNorm2d whose output is multiplied by ``scale * threshold``.

    Original author's note: usable with LIF_S, provided the threshold is refreshed
    after each ``optimizer.step``.

    Attributes:
        tdbn: Marker flag, always 1.
        bn_layer: The wrapped ``nn.BatchNorm2d``.
        scale: Constant output multiplier.
        threshold: Second constant output multiplier.
    """

    def __init__(self, channel, threshold=1.0, scale=1.0):
        super().__init__()
        self.tdbn = 1
        self.bn_layer = nn.BatchNorm2d(channel)
        self.scale = scale
        self.threshold = threshold

    def forward(self, x):
        normalized = self.bn_layer(x)
        # Multiply in the same order as before (scale, then threshold) so results
        # are bit-identical.
        return normalized * self.scale * self.threshold


def td_bn(channel, threshold=1.0, scale=1.0):
    """Norm factory building a ``TdBN`` with the given channel count, threshold and scale."""
    return TdBN(channel, threshold=threshold, scale=scale)


class SewBN(nn.Module):
    """Thin wrapper around ``nn.BatchNorm2d`` exposing it as ``bn_layer``.

    Unlike ``TdBN``, ``threshold`` and ``scale`` are accepted but ignored, so the
    output is exactly the BatchNorm output. Original author's note: usable with
    LIF_S, provided the threshold is refreshed after each ``optimizer.step``.

    Attributes:
        tdbn: Marker flag, always 1.
        bn_layer: The wrapped ``nn.BatchNorm2d``.
    """

    def __init__(self, channel, threshold=1.0, scale=1.0):
        super().__init__()
        self.tdbn = 1
        self.bn_layer = nn.BatchNorm2d(channel)

    def forward(self, x):
        normalized = self.bn_layer(x)
        return normalized


def sew_bn(channel, threshold=1.0, scale=1.0):
    """Norm factory building a ``SewBN`` (``threshold``/``scale`` are ignored by SewBN)."""
    return SewBN(channel, threshold=threshold, scale=scale)


# ============================== End of Some BatchNorm layer =====================================

class BaseModel(nn.Module):
    """Shared base for the spiking models in this file.

    Construction calls ``Neuron.clear_neuron()`` before any subclass layers are
    built. Presumably this empties the ``Neuron.instance`` registry that
    ``reset_model_state`` iterates -- TODO confirm in ``iclr23code``. The remaining
    methods are optional weight-initialisation helpers.
    """

    def __init__(self):
        super(BaseModel, self).__init__()
        Neuron.clear_neuron()

    def init_bias(self):
        """Zero the bias of every ``Conv2d`` / ``Linear`` layer that has one."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                if layer.bias is not None:
                    layer.bias.data.fill_(0)

    def init_weight_uniform(self):
        """Draw ``Conv2d`` / ``Linear`` weights from U(0, 1) and zero their biases."""
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Linear)):
                layer.weight.data.uniform_(0, 1)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    def init_weight_kaiming(self):
        """He-normal (fan-in) init for Conv2d/Linear; gamma=1, beta=0 for BatchNorm2d.

        Conv2d std is ``sqrt(2 / (kh * kw * in_channels))``; Linear std is
        ``sqrt(2 / in_features)``. Biases are zeroed where present. BatchNorm2d
        layers built with ``affine=False`` (weight/bias are ``None``) are skipped
        instead of raising.
        """
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                n = layer.kernel_size[0] * layer.kernel_size[1] * layer.in_channels
                layer.weight.data.normal_(0, sqrt(2. / n))
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.Linear):
                layer.weight.data.normal_(0, sqrt(2. / layer.in_features))
                if layer.bias is not None:
                    layer.bias.data.zero_()
            elif isinstance(layer, nn.BatchNorm2d):
                # affine=False BatchNorm has no learnable weight/bias (both None).
                if layer.weight is not None:
                    layer.weight.data.fill_(1)
                if layer.bias is not None:
                    layer.bias.data.zero_()

    @staticmethod
    def _zero_norm_weight(norm):
        """Zero the affine scale (gamma) of a normalisation module, if it has one.

        Wrappers that keep their BatchNorm in ``bn_layer`` (``TdBN``, ``SewBN``) are
        unwrapped first. Modules with no ``weight`` attribute (e.g. the identity
        ``Straight``) or with ``weight=None`` (affine=False) are left untouched,
        since there is no gamma to zero.
        """
        target = getattr(norm, 'bn_layer', norm)
        weight = getattr(target, 'weight', None)
        if weight is not None:
            torch.nn.init.constant_(weight, 0.)

    def zero_basic_blocks(self):
        """Zero gamma of the final norm (``residual[3]``) in every ``TdBlock``.

        Previously only ``TdBN`` was unwrapped, so ``SewBN`` or ``Straight`` norms
        raised AttributeError.
        """
        for layer in self.modules():
            if isinstance(layer, TdBlock):
                self._zero_norm_weight(layer.residual[3])

    def zero_sew_blocks(self):
        """Zero gamma of the norm inside the second neuron of every ``SewBlock``.

        ``residual[3]`` is the second spiking neuron; its norm module is assumed to
        be exposed as ``.bn`` (attribute defined in ``iclr23code`` -- confirm).
        Previously a ``SewBN`` norm (the natural choice here) raised AttributeError
        because only ``TdBN`` was unwrapped.
        """
        for layer in self.modules():
            if isinstance(layer, SewBlock):
                self._zero_norm_weight(layer.residual[3].bn)


# ==================================== Some ResNet Blocks ====================================
class SewBlock(nn.Module):
    """SEW-style residual block whose output is ``residual(x) + identity(x)``.

    residual: 3x3 conv (stride ``stride``) -> spiking neuron -> 3x3 conv -> spiking neuron.
    identity: a 1x1 conv (stride ``stride``) + spiking neuron when the channel count
    or stride changes; otherwise an empty ``nn.Sequential`` (pass-through).

    Every spiking neuron is built as
    ``neuron(t, bn(out_channel, threshold), spike_func, slope, threshold, weak_mem,
    reset_mechanism)``, with a fresh norm module each time.
    """

    def __init__(self, in_channel, out_channel, stride, neuron, t, bn, spike_func=None, slope=1.0,
                 threshold=1.0, weak_mem=0.5, reset_mechanism='zero'):
        super().__init__()

        def spiking():
            # New norm + neuron per call; call order mirrors module order.
            return neuron(t, bn(out_channel, threshold), spike_func, slope, threshold, weak_mem,
                          reset_mechanism)

        self.residual = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1, bias=False),
            spiking(),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, bias=False),
            spiking(),
        )

        needs_projection = stride != 1 or in_channel != out_channel
        if not needs_projection:
            self.identity = nn.Sequential()
        else:
            self.identity = nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 1, stride, bias=False),
                spiking(),
            )

    def forward(self, x):
        branch = self.residual(x)
        shortcut = self.identity(x)
        return branch + shortcut


class TdBlock(nn.Module):
    """tdBN-style residual block with a spiking neuron applied after the sum.

    residual: 3x3 conv (stride ``stride``) -> ``neuron(t, bn(out_channel, threshold), ...)``
              -> 3x3 conv -> ``bn(out_channel, threshold, sqrt(0.5))``.
    identity: 1x1 conv (stride ``stride``) -> ``bn(out_channel, threshold, sqrt(0.5))``
              when the channel count or stride changes, otherwise an empty
              ``nn.Sequential`` (pass-through).
    The branch sum is fed to ``self.lif``, a neuron built around ``straight()``
    (no normalisation).
    """

    def __init__(self, in_channel, out_channel, stride, neuron, t, bn, spike_func=None, slope=1.0,
                 threshold=1.0, weak_mem=0.5, reset_mechanism='zero'):
        super().__init__()
        branch_scale = sqrt(0.5)

        self.residual = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, 3, stride, 1, bias=False),
            neuron(t, bn(out_channel, threshold), spike_func, slope, threshold, weak_mem,
                   reset_mechanism),
            nn.Conv2d(out_channel, out_channel, 3, 1, 1, bias=False),
            bn(out_channel, threshold, branch_scale),
        )

        reshapes = stride != 1 or in_channel != out_channel
        self.identity = (
            nn.Sequential(
                nn.Conv2d(in_channel, out_channel, 1, stride, bias=False),
                bn(out_channel, threshold, branch_scale),
            )
            if reshapes
            else nn.Sequential()
        )

        self.lif = neuron(t, straight(), spike_func, slope, threshold, weak_mem, reset_mechanism)

    def forward(self, x):
        summed = self.residual(x) + self.identity(x)
        return self.lif(summed)


# ================================== End of Some ResNet Blocks ==============================

class BaseResNet(BaseModel):
    """Base class for the residual spiking networks in this file.

    Adds :meth:`_make_layer`, which stacks residual blocks into one stage.
    """

    def __init__(self):
        # Previously ``super(BaseModel, self).__init__()``. That jumped straight to
        # ``nn.Module.__init__`` and silently skipped ``BaseModel.__init__`` (and its
        # ``Neuron.clear_neuron()`` call), unlike every other BaseModel subclass.
        super(BaseResNet, self).__init__()

    @staticmethod
    def _make_layer(in_channel, out_channel, stride, block_type, block_num, neuron, t, bn,
                    spike_func=None, slope=1.0, threshold=1.0, weak_mem=0.5,
                    reset_mechanism='zero'):
        """Build one ResNet stage made of ``block_num`` blocks of type ``block_type``.

        Only the first block changes width (``in_channel -> out_channel``) and uses
        ``stride``. The remaining blocks map ``out_channel -> out_channel`` at
        stride 1. All other arguments are forwarded positionally to ``block_type``
        in the order ``(neuron, t, bn, spike_func, slope, threshold, weak_mem,
        reset_mechanism)``.

        Returns:
            ``nn.Sequential`` containing the blocks in construction order.

        Raises:
            ValueError: if ``block_num`` is not positive.
        """
        if block_num <= 0:
            raise ValueError("`ResNet` only supports positive `block_num`.")
        shared = (neuron, t, bn, spike_func, slope, threshold, weak_mem, reset_mechanism)
        layers = [block_type(in_channel, out_channel, stride, *shared)]
        layers.extend(block_type(out_channel, out_channel, 1, *shared)
                      for _ in range(block_num - 1))
        return nn.Sequential(*layers)


# ====================================== Model Used in Paper =================================
class AvgVgg16(BaseModel):
    """VGG-16-style spiking classifier using average pooling.

    Thirteen 3x3 convolutions (with bias), each followed by
    ``neuron(t, bn(channels, threshold), ...)``, arranged in five stages. The first
    four stages end with ``AvgPool2d(2)``. These are followed by two 4096-wide
    fully-connected layers (each with a ``straight()``-normalised neuron) and a
    linear read-out with 100 outputs when ``dataset == 'CIFAR100'`` and 10 otherwise.

    The first conv expects 3 input channels, and ``fc1`` expects a 512x2x2 feature
    map (i.e. 32x32 inputs). ``NBAvgVgg16`` indexes into ``self.feature``; the
    construction order below keeps every index identical.
    """

    def __init__(self, neuron, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero', dataset='MNIST'):
        super(AvgVgg16, self).__init__()

        def spiking(norm):
            return neuron(t, norm, spike_func, slope, threshold, weak_mem, reset_mechanism)

        # Conv output widths in order; 'pool' marks an AvgPool2d(2).
        plan = (64, 64, 'pool',
                128, 128, 'pool',
                256, 256, 256, 'pool',
                512, 512, 512, 'pool',
                512, 512, 512)
        stack = []
        channels = 3
        for step in plan:
            if step == 'pool':
                stack.append(nn.AvgPool2d(2))
                continue
            # Conv first, then its norm, then its neuron -- same construction order
            # (and hence same RNG use / Neuron registration order) as before.
            stack.append(nn.Conv2d(channels, step, 3, 1, 1))
            stack.append(spiking(bn(step, threshold)))
            channels = step
        self.feature = nn.Sequential(*stack)

        self.fc1 = nn.Sequential(nn.Linear(512 * 2 * 2, 4096), spiking(straight()))
        self.fc2 = nn.Sequential(nn.Linear(4096, 4096), spiking(straight()))

        num_classes = 100 if dataset == 'CIFAR100' else 10
        self.fc3 = nn.Linear(4096, num_classes)

    def forward(self, x):
        features = self.feature(x)
        flat = features.reshape(features.size(0), -1)
        return self.fc3(self.fc2(self.fc1(flat)))


class NBAvgVgg16(AvgVgg16):
    """``AvgVgg16`` variant whose convolutions in ``self.feature`` carry no bias.

    After the parent builds the network, the bias of each of the 13 convolutions is
    removed by position. ``init_bias`` then zeroes any remaining Conv2d/Linear
    biases (e.g. those of the fully-connected layers).
    """

    def __init__(self, neuron, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero', dataset='MNIST'):
        super(NBAvgVgg16, self).__init__(neuron, t, bn, spike_func, slope, threshold, weak_mem,
                                         reset_mechanism, dataset)
        # Positions of the nn.Conv2d modules inside AvgVgg16.feature.
        conv_positions = (0, 2, 5, 7, 10, 12, 14, 17, 19, 21, 24, 26, 28)
        for pos in conv_positions:
            self.feature[pos].bias = None

        self.init_bias()
        self.init_bias()


class TdResNet19C10(BaseResNet):
    """ResNet-19-style spiking network built from ``TdBlock`` stages, with 10 outputs.

    - Stem: 3x3 conv (3 -> 64) followed by a spiking neuron.
    - Three stages of 3, 3 and 2 ``TdBlock``s with widths 128/256/512 and strides
      1/2/2.
    - Head: global average pooling, a 512 -> 256 linear layer with a
      ``straight()``-normalised neuron, and a 256 -> 10 read-out.

    ``dataset`` is accepted for a uniform constructor signature but is not used.
    """

    def __init__(self, neuron, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero', dataset='MNIST'):
        super(TdResNet19C10, self).__init__()

        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1, bias=False),
            neuron(t, bn(64, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
        )

        def stage(c_in, c_out, stride, depth):
            return self._make_layer(c_in, c_out, stride, TdBlock, depth, neuron, t, bn,
                                    spike_func, slope, threshold, weak_mem, reset_mechanism)

        self.layer1 = stage(64, 128, 1, 3)
        self.layer2 = stage(128, 256, 2, 3)
        self.layer3 = stage(256, 512, 2, 2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(512, 256)
        self.lif1 = neuron(t, straight(), spike_func, slope, threshold, weak_mem, reset_mechanism)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        out = self.pre_layer(x)
        for block_stage in (self.layer1, self.layer2, self.layer3):
            out = block_stage(out)
        out = self.avgpool(out)
        out = self.lif1(self.fc1(out.reshape(out.size(0), -1)))
        return self.fc2(out)


class SewResNet34(BaseResNet):
    """SEW-ResNet-34-style spiking network with a 1000-way output.

    - Stem: 7x7 stride-2 conv (3 -> 64), a spiking neuron, then 3x3 stride-2
      max-pooling.
    - Four stages of 3, 4, 6 and 3 ``SewBlock``s with widths 64/128/256/512 and
      strides 1/2/2/2.
    - Head: global average pooling and a 512 -> 1000 linear layer.

    Initialisation: every Conv2d gets Kaiming-normal weights (``fan_out``, ReLU
    gain); every BatchNorm2d/GroupNorm gets weight 1 and bias 0. ``dataset`` is
    accepted but not used.
    """

    def __init__(self, neuron, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero', dataset='MNIST'):
        super(SewResNet34, self).__init__()

        self.pre_layer = nn.Sequential(
            nn.Conv2d(3, 64, 7, 2, 3, bias=False),
            neuron(t, bn(64, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )

        def stage(c_in, c_out, stride, depth):
            return self._make_layer(c_in, c_out, stride, SewBlock, depth, neuron, t, bn,
                                    spike_func, slope, threshold, weak_mem, reset_mechanism)

        self.layer1 = stage(64, 64, 1, 3)
        self.layer2 = stage(64, 128, 2, 4)
        self.layer3 = stage(128, 256, 2, 6)
        self.layer4 = stage(256, 512, 2, 3)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, 1000)

        # Single pass in module-registration order, so the RNG draws for the conv
        # weights happen in the same sequence as before.
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        out = self.pre_layer(x)
        for block_stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = block_stage(out)
        pooled = self.avgpool(out)
        return self.fc(pooled.reshape(pooled.size(0), -1))


class DspikeModel(BaseResNet):
    """ResNet-style spiking network with 2-channel input and 10 outputs.

    - Stem: three 3x3 convs (2 -> 64, 64 -> 64, then 64 -> 64 at stride 2), each
      followed by a spiking neuron.
    - Four stages of two ``TdBlock``s each, with widths 64/128/256/512 and strides
      1/2/2/2.
    - Head: global average pooling and a 512 -> 10 linear layer.

    Initialisation:
    - Conv2d weights ~ N(0, sqrt(2 / (kh * kw * out_channels))).
    - BatchNorm2d weight 1, bias 0.
    - Linear weights ~ N(0, 1 / in_features), bias 0.

    ``dataset`` is accepted but not used.
    """

    def __init__(self, neuron, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero', dataset='MNIST'):
        super(DspikeModel, self).__init__()

        stem = []
        for c_in, stride in ((2, 1), (64, 1), (64, 2)):
            stem.append(nn.Conv2d(c_in, 64, 3, stride, 1, bias=False))
            stem.append(neuron(t, bn(64, threshold), spike_func, slope, threshold, weak_mem,
                               reset_mechanism))
        self.pre_layer = nn.Sequential(*stem)

        def stage(c_in, c_out, stride):
            return self._make_layer(c_in, c_out, stride, TdBlock, 2, neuron, t, bn, spike_func,
                                    slope, threshold, weak_mem, reset_mechanism)

        self.layer1 = stage(64, 64, 1)
        self.layer2 = stage(64, 128, 2)
        self.layer3 = stage(128, 256, 2)
        self.layer4 = stage(256, 512, 2)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, 10)

        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                fan_out = module.kernel_size[0] * module.kernel_size[1] * module.out_channels
                module.weight.data.normal_(0, sqrt(2. / fan_out))
            elif isinstance(module, nn.BatchNorm2d):
                module.weight.data.fill_(1)
                module.bias.data.zero_()
            elif isinstance(module, nn.Linear):
                fan_in = module.weight.size(1)
                # NOTE(review): std is 1/fan_in (not 1/sqrt(fan_in)); kept as-is --
                # confirm this is the intended scheme.
                module.weight.data.normal_(0, 1.0 / float(fan_in))
                module.bias.data.zero_()

    def forward(self, x):
        out = self.pre_layer(x)
        for block_stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = block_stage(out)
        pooled = self.avgpool(out)
        return self.fc(pooled.reshape(pooled.size(0), -1))

class VGG11(nn.Module):
    """VGG-11-style spiking classifier for 2-channel square inputs, with 10 outputs.

    Eight 3x3 convolutions (with bias), each followed by
    ``neuron(t, bn(channels, threshold), ...)``, in four stages that each end with
    ``AvgPool2d(2)``. A single linear layer maps the flattened 512-channel feature
    map to 10 logits. Conv weights get Kaiming-normal init (``fan_out``, ReLU gain).

    Unlike the ``BaseModel`` subclasses in this file, this class derives from
    ``nn.Module`` directly, so ``Neuron.clear_neuron()`` is not called on construction.

    Args:
        input_size: Side length H == W of the input. The classifier width is
            derived from it as ``512 * (input_size // 16) ** 2``. Defaults to 48,
            the value previously hard-coded. Must be at least 16.
        dataset: Accepted for a uniform constructor signature; not used.

    Raises:
        ValueError: if ``input_size`` is smaller than 16.
    """

    def __init__(self, neuron, t, bn, spike_func=None, slope=1.0, threshold=1.0, weak_mem=1.0,
                 reset_mechanism='zero', dataset='MNIST', input_size=48):
        super(VGG11, self).__init__()
        if input_size < 16:
            raise ValueError("`input_size` must be at least 16, got %r" % (input_size,))

        self.feature = nn.Sequential(
            nn.Conv2d(2, 64, 3, 1, 1),
            neuron(t, bn(64, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.Conv2d(64, 128, 3, 1, 1),
            neuron(t, bn(128, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.AvgPool2d(2),

            nn.Conv2d(128, 256, 3, 1, 1),
            neuron(t, bn(256, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.Conv2d(256, 256, 3, 1, 1),
            neuron(t, bn(256, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.AvgPool2d(2),

            nn.Conv2d(256, 512, 3, 1, 1),
            neuron(t, bn(512, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.Conv2d(512, 512, 3, 1, 1),
            neuron(t, bn(512, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.AvgPool2d(2),

            nn.Conv2d(512, 512, 3, 1, 1),
            neuron(t, bn(512, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.Conv2d(512, 512, 3, 1, 1),
            neuron(t, bn(512, threshold), spike_func, slope, threshold, weak_mem, reset_mechanism),
            nn.AvgPool2d(2),
        )
        # The 3x3/pad-1 convs keep the spatial size and each AvgPool2d(2) floors
        # it by 2, so four pools give floor(input_size / 16). For the default of
        # 48 this is 3, the value previously hard-coded as int(48 / 2 / 2 / 2 / 2).
        side = input_size // 16
        self.classifier = nn.Linear(512 * side * side, 10)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def forward(self, x):
        x = self.feature(x)
        x = self.classifier(x.reshape(x.size(0), -1))
        return x

